badumbatish
Contributor

Fixes #50154

@llvmbot
Member

llvmbot commented Aug 1, 2025

@llvm/pr-subscribers-backend-webassembly

Author: Jasmine Tang (badumbatish)

Changes

Fixes #50154


Full diff: https://github.com/llvm/llvm-project/pull/151775.diff

2 Files Affected:

  • (modified) llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp (+51)
  • (added) llvm/test/CodeGen/WebAssembly/simd-dot-reductions.ll (+21)
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
index cd434f7a331e4..648e3b6b2b440 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
@@ -192,6 +192,9 @@ WebAssemblyTargetLowering::WebAssemblyTargetLowering(
     // Combine wide-vector muls, with extend inputs, to extmul_half.
     setTargetDAGCombine(ISD::MUL);
 
+    // Combine add with vector shuffle of muls to dots
+    setTargetDAGCombine(ISD::ADD);
+
     // Combine vector mask reductions into alltrue/anytrue
     setTargetDAGCombine(ISD::SETCC);
 
@@ -3436,6 +3439,52 @@ static SDValue performSETCCCombine(SDNode *N,
   return SDValue();
 }
 
+static SDValue performAddCombine(SDNode *N, SelectionDAG &DAG) {
+  assert(N->getOpcode() == ISD::ADD);
+  EVT VT = N->getValueType(0);
+  SDValue N0 = N->getOperand(0), N1 = N->getOperand(1);
+
+  if (VT != MVT::v4i32)
+    return SDValue();
+
+  auto IsShuffleWithMask = [](SDValue V, ArrayRef<int> ShuffleValue) {
+    if (V.getOpcode() != ISD::VECTOR_SHUFFLE)
+      return SDValue();
+    if (cast<ShuffleVectorSDNode>(V)->getMask() != ShuffleValue)
+      return SDValue();
+    return V;
+  };
+  auto ShuffleA = IsShuffleWithMask(N0, {0, 2, 4, 6});
+  auto ShuffleB = IsShuffleWithMask(N1, {1, 3, 5, 7});
+  // Both add operands must be shuffles with the expected masks.
+  if (!ShuffleA || !ShuffleB)
+    return SDValue();
+
+  if (ShuffleA.getOperand(0) != ShuffleB.getOperand(0) ||
+      ShuffleA.getOperand(1) != ShuffleB.getOperand(1))
+    return SDValue();
+
+  auto IsMulExtend =
+      [](SDValue V, WebAssemblyISD::NodeType I) -> std::pair<SDValue, SDValue> {
+    if (V.getOpcode() != ISD::MUL)
+      return {};
+
+    auto V0 = V.getOperand(0), V1 = V.getOperand(1);
+    if (V0.getOpcode() != I || V1.getOpcode() != I)
+      return {};
+    return {V0.getOperand(0), V1.getOperand(0)};
+  };
+
+  auto [LowA, LowB] =
+      IsMulExtend(ShuffleA.getOperand(0), WebAssemblyISD::EXTEND_LOW_S);
+  auto [HighA, HighB] =
+      IsMulExtend(ShuffleA.getOperand(1), WebAssemblyISD::EXTEND_HIGH_S);
+
+  if (!LowA || !LowB || !HighA || !HighB || LowA != HighA || LowB != HighB)
+    return SDValue();
+
+  return DAG.getNode(WebAssemblyISD::DOT, SDLoc(N), MVT::v4i32, LowA, LowB);
+}
 static SDValue performMulCombine(SDNode *N, SelectionDAG &DAG) {
   assert(N->getOpcode() == ISD::MUL);
   EVT VT = N->getValueType(0);
@@ -3558,5 +3607,7 @@ WebAssemblyTargetLowering::PerformDAGCombine(SDNode *N,
   }
   case ISD::MUL:
     return performMulCombine(N, DCI.DAG);
+  case ISD::ADD:
+    return performAddCombine(N, DCI.DAG);
   }
 }
diff --git a/llvm/test/CodeGen/WebAssembly/simd-dot-reductions.ll b/llvm/test/CodeGen/WebAssembly/simd-dot-reductions.ll
new file mode 100644
index 0000000000000..7ac49794491a1
--- /dev/null
+++ b/llvm/test/CodeGen/WebAssembly/simd-dot-reductions.ll
@@ -0,0 +1,21 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc < %s -mattr=+simd128 | FileCheck %s
+
+target triple = "wasm32-unknown-unknown"
+define <4 x i32> @dot(<8 x i16> %a, <8 x i16> %b) {
+; CHECK-LABEL: dot:
+; CHECK:         .functype dot (v128, v128) -> (v128)
+; CHECK-NEXT:  # %bb.0:
+; CHECK-NEXT:    local.get 0
+; CHECK-NEXT:    local.get 1
+; CHECK-NEXT:    i32x4.dot_i16x8_s
+; CHECK-NEXT:    # fallthrough-return
+  %sext1 = sext <8 x i16> %a to <8 x i32>
+  %sext2 = sext <8 x i16> %b to <8 x i32>
+  %mul = mul nsw <8 x i32> %sext1, %sext2
+  %shuffle1 = shufflevector <8 x i32> %mul, <8 x i32> poison, <4 x i32> <i32 0, i32 2, i32 4, i32 6>
+  %shuffle2 = shufflevector <8 x i32> %mul, <8 x i32> poison, <4 x i32> <i32 1, i32 3, i32 5, i32 7>
+  %res = add <4 x i32> %shuffle1, %shuffle2
+  ret <4 x i32> %res
+}
+
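
The test above relies on the even/odd shuffle-and-add sequence being equivalent to `i32x4.dot_i16x8_s`. A minimal scalar sketch of that equivalence follows. The helper names `dotReference` and `dotViaShuffles` and the `I16x8`/`I32x4` aliases are hypothetical and exist only for this illustration; the lane addition is modelled as wrapping modulo 2^32.

```cpp
#include <array>
#include <cassert>
#include <cstdint>

using I16x8 = std::array<int16_t, 8>;
using I32x4 = std::array<int32_t, 4>;

// Scalar model of i32x4.dot_i16x8_s: each output lane is the wrapping sum of
// two adjacent sign-extended 16-bit products.
I32x4 dotReference(const I16x8 &a, const I16x8 &b) {
  I32x4 r{};
  for (int i = 0; i < 4; ++i) {
    uint32_t lo = static_cast<uint32_t>(int32_t(a[2 * i]) * int32_t(b[2 * i]));
    uint32_t hi =
        static_cast<uint32_t>(int32_t(a[2 * i + 1]) * int32_t(b[2 * i + 1]));
    r[i] = static_cast<int32_t>(lo + hi); // unsigned add wraps like the lane add
  }
  return r;
}

// Scalar model of the IR in the test: sext, mul, two shuffles, add.
I32x4 dotViaShuffles(const I16x8 &a, const I16x8 &b) {
  std::array<int32_t, 8> mul{};
  for (int i = 0; i < 8; ++i)
    mul[i] = int32_t(a[i]) * int32_t(b[i]); // |product| <= 2^30, no overflow
  const int evenMask[4] = {0, 2, 4, 6};     // mask of %shuffle1
  const int oddMask[4] = {1, 3, 5, 7};      // mask of %shuffle2
  I32x4 r{};
  for (int i = 0; i < 4; ++i)
    r[i] = static_cast<int32_t>(static_cast<uint32_t>(mul[evenMask[i]]) +
                                static_cast<uint32_t>(mul[oddMask[i]]));
  return r;
}
```

Both functions should agree on every input, which is what lets the add of the two shuffles collapse into a single dot node.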

Contributor

@lukel97 lukel97 left a comment

Out of curiosity, is it possible to do this as a tablegen pattern? Not that it's necessarily the right thing to do, just wondering if it's easy to do or not!

@badumbatish
Contributor Author

Out of curiosity, is it possible to do this as a tablegen pattern? Not that it's necessarily the right thing to do, just wondering if it's easy to do or not!

I actually didn't think tablegen could handle identical arguments in a pattern, but I just tried it out and I think I might be able to make it work.

@sparker-arm
Contributor

i think i might be able to make it work

I'm assuming the 'illegal' types will make this more difficult in tablegen? This approach looks good to me.

Contributor

@lukel97 lukel97 left a comment

The PR title needs to be updated to reflect that this isn't a combine but adds a fold pattern.

EDIT: Typo, sorry!

@lukel97
Contributor

lukel97 commented Aug 6, 2025

i think i might be able to make it work

I'm assuming the 'illegal' types will make this more difficult in tablegen? This approach looks good to me.

Good point. It looks like the old combine only operated on MVT::v8i16s feeding into MVT::v4i32 adds, but I presume we also want to handle i8s that are sext'd to i32, e.g.:

define <4 x i32> @f(<8 x i8> %a, <8 x i8> %b) {
  %sext1 = sext <8 x i8> %a to <8 x i32>
  %sext2 = sext <8 x i8> %b to <8 x i32>
  %mul = mul nsw <8 x i32> %sext1, %sext2
  %shuffle1 = shufflevector <8 x i32> %mul, <8 x i32> poison, <4 x i32> <i32 0, i32 2, i32 4, i32 6>
  %shuffle2 = shufflevector <8 x i32> %mul, <8 x i32> poison, <4 x i32> <i32 1, i32 3, i32 5, i32 7>
  %res = add <4 x i32> %shuffle1, %shuffle2
  ret <4 x i32> %res
}

And from a quick check it looks like the multiply is hoisted into the narrower type:

Optimized legalized selection DAG: %bb.0 'f:'
SelectionDAG has 46 nodes:
      t2: v16i8 = WebAssemblyISD::ARGUMENT TargetConstant:i32<0>
    t40: v8i16 = WebAssemblyISD::EXTEND_LOW_S t2
      t4: v16i8 = WebAssemblyISD::ARGUMENT TargetConstant:i32<1>
    t39: v8i16 = WebAssemblyISD::EXTEND_LOW_S t4
  t35: v8i16 = mul t40, t39
  t37: v4i32 = WebAssemblyISD::EXTEND_HIGH_S t35
  t36: v4i32 = WebAssemblyISD::EXTEND_LOW_S t35
    t0: ch,glue = EntryToken
      t77: v4i32 = WebAssemblyISD::SHUFFLE t36, t37, Constant:i32<0>, Constant:i32<1>, Constant:i32<2>, Constant:i32<3>, Constant:i32<8>, Constant:i32<9>, Constant:i32<10>, Constant:i32<11>, Constant:i32<16>, Constant:i32<17>, Constant:i32<18>, Constant:i32<19>, Constant:i32<24>, Constant:i32<25>, Constant:i32<26>, Constant:i32<27>
      t60: v4i32 = WebAssemblyISD::SHUFFLE t36, t37, Constant:i32<4>, Constant:i32<5>, Constant:i32<6>, Constant:i32<7>, Constant:i32<12>, Constant:i32<13>, Constant:i32<14>, Constant:i32<15>, Constant:i32<20>, Constant:i32<21>, Constant:i32<22>, Constant:i32<23>, Constant:i32<28>, Constant:i32<29>, Constant:i32<30>, Constant:i32<31>
    t30: v4i32 = add t77, t60
  t31: ch = WebAssemblyISD::RETURN t0, t30
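
When reading the dump above, it may help to note that the sixteen constants on each `WebAssemblyISD::SHUFFLE` node appear to be byte indices into the concatenation of the two inputs, so each i32 lane index expands to four consecutive bytes. A minimal sketch of that expansion, with `laneMaskToByteMask` as a hypothetical helper written only for this illustration:

```cpp
#include <array>
#include <cassert>
#include <cstdint>

// Expand a 4-lane i32 shuffle mask into 16 byte indices:
// lane L covers bytes 4*L through 4*L+3 of the concatenated inputs.
std::array<uint8_t, 16> laneMaskToByteMask(const std::array<int, 4> &lanes) {
  std::array<uint8_t, 16> bytes{};
  for (int i = 0; i < 4; ++i)
    for (int j = 0; j < 4; ++j)
      bytes[4 * i + j] = static_cast<uint8_t>(4 * lanes[i] + j);
  return bytes;
}
```

Under that reading, the lane mask {0, 2, 4, 6} expands to the byte indices listed on t77, and {1, 3, 5, 7} to those on t60, so these are still the same even/odd lane shuffles as in the IR.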

We could add a second tablegen pattern for (add (shuffle (sext mul)) (shuffle (sext mul))).

Or we could try and do this again as a dagcombine, but it looks like the multiply is already hoisted by the time the types are legalized so we would have to handle the two different patterns anyway:

Type-legalized selection DAG: %bb.0 'f:'
SelectionDAG has 26 nodes:
      t2: v16i8 = WebAssemblyISD::ARGUMENT TargetConstant:i32<0>
    t40: v8i16 = WebAssemblyISD::EXTEND_LOW_S t2
      t4: v16i8 = WebAssemblyISD::ARGUMENT TargetConstant:i32<1>
    t39: v8i16 = WebAssemblyISD::EXTEND_LOW_S t4
  t35: v8i16 = mul t40, t39
  t36: v4i32 = WebAssemblyISD::EXTEND_LOW_S t35
  t37: v4i32 = WebAssemblyISD::EXTEND_HIGH_S t35
    t0: ch,glue = EntryToken
        t12: i32 = extract_vector_elt t36, Constant:i32<0>
        t14: i32 = extract_vector_elt t36, Constant:i32<2>
        t16: i32 = extract_vector_elt t37, Constant:i32<0>
        t18: i32 = extract_vector_elt t37, Constant:i32<2>
      t19: v4i32 = BUILD_VECTOR t12, t14, t16, t18
        t21: i32 = extract_vector_elt t36, Constant:i32<1>
        t23: i32 = extract_vector_elt t36, Constant:i32<3>
        t25: i32 = extract_vector_elt t37, Constant:i32<1>
        t27: i32 = extract_vector_elt t37, Constant:i32<3>
      t28: v4i32 = BUILD_VECTOR t21, t23, t25, t27
    t30: v4i32 = add t19, t28
  t31: ch = WebAssemblyISD::RETURN t0, t30
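
One plausible reason the multiply can be done in v8i16 here is that the product of two sign-extended i8 values always fits in 16 bits, so multiplying narrow and extending afterwards gives the same result as extending first. The following is a small exhaustive check of that arithmetic fact only, not a description of which LLVM transform performs the narrowing; `narrowMulMatchesWideMul` is a hypothetical name.

```cpp
#include <cassert>
#include <cstdint>

// Returns true if, for every pair of i8 values, a 16-bit multiply followed by
// sign extension to 32 bits equals a 32-bit multiply of the extended inputs.
bool narrowMulMatchesWideMul() {
  for (int a = -128; a <= 127; ++a) {
    for (int b = -128; b <= 127; ++b) {
      int32_t wide = int32_t(a) * int32_t(b);
      // Model the 16-bit lane multiply: keep only the low 16 bits.
      int16_t narrow = static_cast<int16_t>(static_cast<uint16_t>(a * b));
      if (int32_t(narrow) != wide)
        return false;
    }
  }
  return true;
}
```

The extreme products are (-128) * (-128) = 16384 and (-128) * 127 = -16256, both inside the i16 range, so the check is expected to succeed for all inputs.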

@badumbatish
Contributor Author

badumbatish commented Aug 7, 2025

Godbolt link for this: https://godbolt.org/z/Y1rrnW5h3. The IR at the isel phase looks a bit different. Would you like me to try that in this PR as well?

Contributor

@lukel97 lukel97 left a comment

LGTM with the PR title updated

I think we can handle the zext from i8 case in a separate PR if the pattern ends up being somewhat different anyway, but I'll defer to @sparker-arm on this! (I don't have a strong opinion on whether we do this as a combine or a tablegen pattern.)

@badumbatish badumbatish changed the title from "[WebAssembly] Add fold support for dot" to "[WebAssembly] Add extra pattern for dot" on Aug 11, 2025
@badumbatish
Contributor Author

Hi @sparker-arm, I'm looking to add much the same pattern for the relaxed SIMD dot. Can you confirm whether this approach is good to merge and to reuse in the subsequent PR?

Contributor

@sparker-arm sparker-arm left a comment

Sorry, I didn't know you were waiting on me, please feel free to shout sooner next time!

@badumbatish
Contributor Author

No worries, I'll keep that in mind next time. Thanks for the reviews!

@badumbatish badumbatish merged commit 55d4e92 into llvm:main Oct 13, 2025
9 checks passed
akadutta pushed a commit to akadutta/llvm-project that referenced this pull request Oct 14, 2025
Development

Successfully merging this pull request may close these issues.

[SIMD] pattern for i32x4.dot_i16x8_s not recognized

4 participants